from typing import Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from recbole.utils import InputType
from recbole_cdr.data.dataset import CrossDomainDataset
from recbole_cdr.model.crossdomain_recommender import CrossDomainRecommender
from recbole.model.init import xavier_normal_initialization
from recbole.model.loss import EmbLoss, BPRLoss

class ReverseLayerF(Function):
    """Gradient-reversal layer (GRL) for domain adversarial learning.

    The forward pass returns a view of the input with unchanged values; the
    backward pass multiplies the incoming gradient by ``-alpha``.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, alpha: float):
        """Remember ``alpha`` for the backward pass and pass ``x`` through."""
        ctx.alpha = alpha
        # A fresh view (not ``x`` itself) is returned so autograd records this op.
        return x.view(x.size())

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Return the sign-flipped, ``alpha``-scaled gradient for ``x``."""
        scaled = grad_output * ctx.alpha
        # ``alpha`` is a plain scalar input, so it receives no gradient.
        return -scaled, None


class Encoder(nn.Module):
    """Bias-free two-layer MLP with a ReLU after each layer.

    Used for the shared / private user representations and the item
    representations.
    """

    def __init__(self, in_dim: int, out_dim: int):
        """Create ``in_dim -> out_dim -> out_dim`` linear layers (no bias)."""
        super().__init__()
        self.fc1 = nn.Linear(in_dim, out_dim, bias=False)
        self.fc2 = nn.Linear(out_dim, out_dim, bias=False)
        self.act = nn.ReLU()

    def forward(self, x):
        """Apply ``fc1 -> ReLU -> fc2 -> ReLU`` to ``x``."""
        hidden = self.act(self.fc1(x))
        encoded = self.act(self.fc2(hidden))
        return encoded


class TaskHead(nn.Module):
    """Prediction head for CTR (binary) prediction. Outputs logits.

    Two bias-free linear layers are stacked with no activation in between, so
    the head as a whole is a single linear map of its input.
    """

    def __init__(self, in_dim: int):
        """Create ``in_dim -> in_dim -> 1`` linear layers (no bias)."""
        super().__init__()
        self.linear1 = nn.Linear(in_dim, in_dim, bias=False)
        self.linear2 = nn.Linear(in_dim, 1, bias=False)

    def forward(self, x):
        """Return one logit per row of ``x`` (last dimension of size 1)."""
        projected = self.linear1(x)
        logit = self.linear2(projected)
        return logit


class C2DR(CrossDomainRecommender):
    r"""Robust Cross-Domain Recommendation based on Causal Disentanglement (C2DR).

    Users are encoded by one shared encoder plus one private encoder per
    domain ("A" = target, "B" = source); items by one encoder per domain.
    ``calculate_loss`` returns a different objective depending on
    ``self.stage`` (switched through :meth:`set_phase`):

    * stage 1 -- source/target prediction losses, auxiliary prediction losses
      computed from the shared user representation only, and a
      domain-confusion loss applied through a gradient-reversal layer;
    * stage 2 -- only the weighted decorrelation loss between the shared and
      private user representations, scaled by ``lambda_vec``;
    * any other stage (stage 3) -- source/target prediction losses weighted
      per sample by ``softmax(omega)``.

    NOTE(review): ``omega`` is (re)created inside ``calculate_loss``. An
    optimizer constructed before the first stage-2/3 batch will not contain it
    and therefore never updates it; its length is tied to the batch size, not
    to sample identity. Confirm how the trainer is expected to optimise it.

    NOTE(review): with ``loss_type == 'BPR'`` the loss is still invoked as
    ``loss(prediction, label)``; confirm this is the intended pairwise usage.
    """

    # Number of stage-1 ``calculate_loss`` calls over which the GRL progress
    # ``p`` goes from 0 to 1, unless ``alpha_schedule["total_steps"]`` is set.
    DEFAULT_GRL_TOTAL_STEPS = 15000

    # Phase names accepted by ``set_phase`` and the stage each one selects.
    _PHASE_TO_STAGE = {
        "train1": 1, "train2": 2, "train3": 3,
        "BOTH": 1, "SOURCE": 2, "TARGET": 3,
        "both": 1, "source": 2, "target": 3,
    }

    def __init__(self, config, dataset: CrossDomainDataset):
        """Build embeddings, encoders, heads and losses from ``config``.

        Reads ``loss_type`` (``'CE'`` or ``'BPR'``), ``embedding_size``,
        ``alpha_schedule`` (optional mapping with keys ``k`` and
        ``total_steps``), ``lambda_confusion``, ``lambda_vec`` and
        ``lambda_reg``.

        Raises:
            ValueError: if ``loss_type`` is neither ``'CE'`` nor ``'BPR'``.
        """
        super().__init__(config, dataset)

        # Loss type handling
        loss_type = config['loss_type']
        if loss_type == 'CE':
            self.input_type = InputType.POINTWISE
            self.loss = nn.BCEWithLogitsLoss()
        elif loss_type == 'BPR':
            self.input_type = InputType.PAIRWISE
            self.loss = BPRLoss()
        else:
            # Fail here instead of with an AttributeError on the first batch.
            raise ValueError(
                f"C2DR: unsupported loss_type {loss_type!r}; expected 'CE' or 'BPR'."
            )

        # Hyper-params
        self.embedding_size = config["embedding_size"]
        # Fall back to an empty mapping so the ``.get`` defaults below apply
        # when the option is missing from the config.
        self.alpha_schedule = config["alpha_schedule"] or {}
        self.lambda_confusion = config["lambda_confusion"]
        self.lambda_vec = config["lambda_vec"]
        self.lambda_reg = config["lambda_reg"]
        self.stage = 1  # Default start at Stage-1

        # Embeddings
        self.user_embedding = nn.Embedding(self.total_num_users, self.embedding_size)
        self.item_embedding = nn.Embedding(self.total_num_items, self.embedding_size)

        # Encoders
        self.encoder_shared = Encoder(self.embedding_size, self.embedding_size)
        self.encoder_userA = Encoder(self.embedding_size, self.embedding_size)  # Target-domain private
        self.encoder_userB = Encoder(self.embedding_size, self.embedding_size)  # Source-domain private
        self.encoder_itemA = Encoder(self.embedding_size, self.embedding_size)  # Target item encoder
        self.encoder_itemB = Encoder(self.embedding_size, self.embedding_size)  # Source item encoder

        # Domain classifier (fed through the GRL); outputs a probability.
        self.domain_clf = nn.Sequential(
            nn.Linear(self.embedding_size, 1),
            nn.Sigmoid()
        )

        # Prediction heads (output logits)
        self.task_a = TaskHead(self.embedding_size * 2)  # Target CTR
        self.task_b = TaskHead(self.embedding_size * 2)  # Source CTR

        # Auxiliary (self-supervised) heads
        self.task_sa = TaskHead(self.embedding_size * 2)  # Shared -> tgt item
        self.task_sb = TaskHead(self.embedding_size * 2)  # Shared -> src item

        # Other stuff
        self.domain_loss = nn.BCELoss()  # Independent BCE for domain confusion
        # NOTE: ``reg_loss`` / ``lambda_reg`` are created but not added to any
        # stage objective in ``calculate_loss``.
        self.reg_loss = EmbLoss()  # Embedding regularization
        self._init_weights()
        self.register_buffer("p", torch.zeros(1))  # GRL progress
        self.register_parameter("omega", None)  # Learnable sample weight (see class note)

    def _init_weights(self):
        """Apply RecBole's Xavier-normal initialisation to every sub-module."""
        self.apply(xavier_normal_initialization)

    def _grl_alpha(self) -> float:
        """Return the GRL coefficient ``2 / (1 + exp(-k * p)) - 1``.

        ``p`` is the progress buffer clamped in place to ``[0, 1]``; ``k``
        comes from ``alpha_schedule["k"]`` (default 10.0).
        """
        k = self.alpha_schedule.get("k", 10.0)
        p = self.p.clamp_(0, 1).item()
        # Cast so the autograd Function receives a plain Python float.
        return float(2.0 / (1 + np.exp(-k * p)) - 1.0)

    def _grl_total_steps(self) -> float:
        """Return the number of stage-1 steps over which ``p`` reaches 1."""
        return float(self.alpha_schedule.get("total_steps", self.DEFAULT_GRL_TOTAL_STEPS))

    def _encode(self, u_emb, i_emb, domain: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return shared-rep, domain-private rep, item-side rep.

        ``domain == "source"`` selects the B encoders; any other value selects
        the A (target) encoders.
        """
        h_s = self.encoder_shared(u_emb)
        if domain == "source":
            h_d = self.encoder_userB(u_emb)
            h_i = self.encoder_itemB(i_emb)
        else:  # target
            h_d = self.encoder_userA(u_emb)
            h_i = self.encoder_itemA(i_emb)
        return h_s, h_d, h_i

    @staticmethod
    def _vec_mse(h1: torch.Tensor, h2: torch.Tensor, weight: Optional[torch.Tensor] = None):
        """Vector orthogonal MSE (Eq.9), with optional weight.

        Computes the mean squared difference between the weighted second
        moment ``E_w[h1^T h2]`` and the product of weighted means
        ``E_w[h1]^T E_w[h2]``, i.e. the mean squared weighted
        cross-covariance of the two ``[bs, d]`` inputs.

        Args:
            h1: ``[bs, d]`` tensor.
            h2: ``[bs, d]`` tensor.
            weight: optional per-row weights, ``[bs, 1]`` or ``[bs]``; they
                are renormalised to sum to 1. Uniform when omitted.
        """
        bs, _ = h1.size()
        if weight is None:
            # ``new_ones`` keeps the dtype/device of the inputs.
            weight = h1.new_ones(bs, 1) / bs
        else:
            if weight.dim() == 1:
                # A flat weight would broadcast along the feature axis.
                weight = weight.unsqueeze(-1)
            weight = weight / weight.sum()  # Normalize
        weighted_h2 = weight * h2
        cov12 = h1.t() @ weighted_h2  # [d, d]
        mu1 = (weight * h1).sum(0, keepdim=True)  # [1, d]
        mu2 = weighted_h2.sum(0, keepdim=True)  # [1, d]
        cov_mu = mu1.t() @ mu2  # [d, d]
        return F.mse_loss(cov12, cov_mu)

    def _sample_weights(self, batch_size: int, device) -> torch.Tensor:
        """Return ``softmax(omega)`` of shape ``[batch_size, 1]``.

        ``omega`` is re-initialised to ones whenever it is missing or its
        length differs from ``batch_size`` (see the class-level review note).
        """
        if self.omega is None or self.omega.size(0) != batch_size:
            self.omega = nn.Parameter(torch.ones(batch_size, 1, device=device))
        return F.softmax(self.omega, dim=0)

    def _per_sample_loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return the un-reduced prediction loss, one value per sample.

        Loss modules do not accept ``reduction`` at call time, so the
        functional form is used for the cross-entropy case.

        Raises:
            NotImplementedError: if the configured loss is not
                ``nn.BCEWithLogitsLoss``.
        """
        if isinstance(self.loss, nn.BCEWithLogitsLoss):
            return F.binary_cross_entropy_with_logits(logits, labels, reduction='none')
        raise NotImplementedError(
            "C2DR stage 3 needs a per-sample loss, which is only implemented "
            "for loss_type='CE'."
        )

    def _confusion_loss(self, hs: torch.Tensor, ht: torch.Tensor) -> torch.Tensor:
        """Domain-confusion loss on the shared representations via the GRL.

        Source rows are labelled 0 and target rows 1. The GRL progress ``p``
        is advanced by one step after the current coefficient is read.
        """
        alpha = self._grl_alpha()
        self.p += 1.0 / self._grl_total_steps()
        d_s = self.domain_clf(ReverseLayerF.apply(hs, alpha)).squeeze(-1)
        d_t = self.domain_clf(ReverseLayerF.apply(ht, alpha)).squeeze(-1)
        return (self.domain_loss(d_s, torch.zeros_like(d_s))
                + self.domain_loss(d_t, torch.ones_like(d_t)))

    def calculate_loss(self, interaction):
        """Return the training loss for the current stage.

        Source and target batches are truncated to their common length so the
        rows can be paired. Only the terms required by the active stage are
        computed; the GRL progress ``p`` advances in stage 1 only, which is
        the only stage that uses the GRL.
        """
        su, si, sy = (interaction[self.SOURCE_USER_ID],
                      interaction[self.SOURCE_ITEM_ID],
                      interaction[self.SOURCE_LABEL].float())
        tu, ti, ty = (interaction[self.TARGET_USER_ID],
                      interaction[self.TARGET_ITEM_ID],
                      interaction[self.TARGET_LABEL].float())

        min_bs = min(su.size(0), tu.size(0))
        su, si, sy = su[:min_bs], si[:min_bs], sy[:min_bs]
        tu, ti, ty = tu[:min_bs], ti[:min_bs], ty[:min_bs]

        u_src, i_src = self.user_embedding(su), self.item_embedding(si)
        u_tgt, i_tgt = self.user_embedding(tu), self.item_embedding(ti)

        hs, hs_d, hi_s = self._encode(u_src, i_src, 'source')
        ht, ht_d, hi_t = self._encode(u_tgt, i_tgt, 'target')

        # NOTE: embedding regularisation (``lambda_reg * reg_loss`` over the
        # user/item embedding tables) is currently not part of any objective.

        if self.stage == 2:
            # Stage 2: decorrelate shared vs. private representations.
            omega = self._sample_weights(min_bs, hs.device)
            loss_orth = (self._vec_mse(hs, hs_d, omega)
                         + self._vec_mse(ht, ht_d, omega)
                         + self._vec_mse(hs_d, ht_d, omega))
            return self.lambda_vec * loss_orth

        pred_s = self.task_b(torch.cat([hs + hs_d, hi_s], dim=-1)).squeeze(-1)
        pred_t = self.task_a(torch.cat([ht + ht_d, hi_t], dim=-1)).squeeze(-1)

        if self.stage == 1:
            loss_s = self.loss(pred_s, sy)
            loss_t = self.loss(pred_t, ty)

            # Auxiliary heads see the shared user representation only.
            pred_sa = self.task_sa(torch.cat([ht, hi_t], dim=-1)).squeeze(-1)
            pred_sb = self.task_sb(torch.cat([hs, hi_s], dim=-1)).squeeze(-1)
            loss_aux = self.loss(pred_sa, ty) + self.loss(pred_sb, sy)

            loss_conf = self._confusion_loss(hs, ht)
            return loss_s + loss_t + loss_aux + self.lambda_confusion * loss_conf

        # Stage 3: per-sample weighted prediction losses. The weights are
        # flattened to [bs] so they multiply element-wise with the [bs] loss
        # vector (a [bs, 1] weight would broadcast to [bs, bs]). They sum to 1,
        # so the weighted sum is the weighted mean.
        omega = self._sample_weights(min_bs, hs.device).squeeze(-1)
        loss_s = (omega * self._per_sample_loss(pred_s, sy)).sum()
        loss_t = (omega * self._per_sample_loss(pred_t, ty)).sum()
        return loss_s + loss_t

    @torch.no_grad()
    def predict(self, interaction):
        """Return target-domain logits for the (user, item) pairs in ``interaction``."""
        u = interaction[self.TARGET_USER_ID]
        i = interaction[self.TARGET_ITEM_ID]
        u_e, i_e = self.user_embedding(u), self.item_embedding(i)
        h_s, h_d, h_i = self._encode(u_e, i_e, "target")
        return self.task_a(torch.cat([h_s + h_d, h_i], dim=-1)).squeeze(-1)

    @torch.no_grad()
    def full_sort_predict(self, interaction):
        """Return flattened ``[bs * n]`` target-domain logits for every user/item pair.

        ``n`` is ``self.target_num_items``; items are taken as the first ``n``
        rows of the item embedding table. Because ``task_a`` is a bias-free
        composition of two linear layers, its score splits into a user term
        plus an item term, which avoids materialising a ``[bs, n, 2d]``
        tensor. This shortcut must be revisited if ``TaskHead`` ever gains a
        non-linearity or a bias.
        """
        user_idx = interaction[self.TARGET_USER_ID]
        user_emb = self.user_embedding(user_idx)  # [bs, embedding_size]

        h_s = self.encoder_shared(user_emb)  # [bs, embedding_size]
        h_d = self.encoder_userA(user_emb)  # [bs, embedding_size]
        user_part = h_s + h_d  # [bs, embedding_size]

        all_items_emb = self.item_embedding.weight[:self.target_num_items]  # [n, embedding_size]
        item_part = self.encoder_itemA(all_items_emb)  # [n, embedding_size]

        # Collapse the two linear layers into one weight vector.
        w = (self.task_a.linear2.weight @ self.task_a.linear1.weight).squeeze(0)  # [2*embedding_size]
        w_user = w[:self.embedding_size]  # [embedding_size]
        w_item = w[self.embedding_size:]  # [embedding_size]

        user_scores = (user_part @ w_user).view(-1, 1)  # [bs, 1]
        item_scores = (item_part @ w_item).view(1, -1)  # [1, n]

        scores = user_scores + item_scores  # [bs, n]
        return scores.view(-1)  # [bs * n]

    def set_phase(self, phase: str):
        """Select the training stage from a phase name and reset GRL progress.

        Known names are ``train1/2/3``, ``BOTH/SOURCE/TARGET`` and their
        lower-case forms. An unknown name leaves the stage untouched and is
        reported through ``self.logger`` when that attribute exists.
        """
        stage = self._PHASE_TO_STAGE.get(phase)
        if stage is not None:
            self.stage = stage
            self.p.zero_()  # Reset GRL progress
        elif hasattr(self, "logger"):
            self.logger.warning(f"C2DR.set_phase: unknown phase \"{phase}\", keep stage={self.stage}.")